from runner import Runner
from common.arguments import get_args
from common.env_wrappers import SubprocVecEnv
from common.helper import Logger, set_all_seeds
import pprint
import torch
import sys
import os

def _ensure_project_root_on_path():
    """Put the parent directory of this file on ``sys.path`` exactly once.

    The scenario modules are imported as ``Env.<module>``, so the directory
    one level above this file has to be importable. The membership check
    keeps repeated ``make_env`` calls from piling up duplicate entries.
    """
    project_root = os.path.abspath(
        os.path.join(os.path.dirname(sys.modules[__name__].__file__), "..")
    )
    if project_root not in sys.path:
        sys.path.append(project_root)


def _resolve_scenario(scenario_name):
    """Map a scenario name to its environment class and action-space attribute.

    Args:
        scenario_name: Either ``'GuessingNumber'`` or ``'RevealingGoal'``.

    Returns:
        A ``(env_class, action_attr)`` tuple, where ``action_attr`` is the name
        of the environment attribute that is copied into ``args.action_shape``.

    Raises:
        NotImplementedError: If ``scenario_name`` is not a supported scenario.
    """
    # Imports are deferred so that only the requested scenario module is loaded.
    if scenario_name == 'GuessingNumber':
        from Env.guessing_number import GuessingNumber
        return GuessingNumber, 'action_space_noop'
    if scenario_name == 'RevealingGoal':
        from Env.revealing_goal import RevealingGoal
        return RevealingGoal, 'action_space'
    raise NotImplementedError(
        "Unsupported scenario_name: {!r}".format(scenario_name)
    )


def make_env(args):
    """Build the environment selected by ``args.scenario_name``.

    One environment instance is always constructed in the calling process so
    that its metadata can be copied onto ``args``:

    * ``args.n_agents``     <- ``env.num_agents``
    * ``args.obs_shape``    <- ``env.observation_space``
    * ``args.action_shape`` <- ``env.action_space_noop`` for ``GuessingNumber``,
      ``env.action_space`` for ``RevealingGoal``

    Args:
        args: Argument namespace. Reads ``scenario_name``, ``vec_env`` and
            ``seed``; the three attributes listed above are written to it.

    Returns:
        ``(env, args)``. ``env`` is a ``SubprocVecEnv`` wrapping
        ``args.vec_env`` environment factories when ``args.vec_env > 1``,
        otherwise the single environment instance built here.

    Raises:
        NotImplementedError: If ``args.scenario_name`` is not supported.
    """
    _ensure_project_root_on_path()
    env_cls, action_attr = _resolve_scenario(args.scenario_name)

    env = env_cls()

    args.n_agents = env.num_agents
    # Per the upstream note: each entry is the obs dimension of one agent.
    args.obs_shape = env.observation_space
    # Per the upstream note: each entry is the action dimension of one agent.
    args.action_shape = getattr(env, action_attr)

    def get_env_fn(rank):
        """Return a zero-argument factory for the environment of index ``rank``."""
        def init_env():
            worker_env = env_cls()
            # Offset the base seed by the rank so every factory seeds differently.
            set_all_seeds(args.seed + rank * 12345)
            return worker_env
        return init_env

    if args.vec_env > 1:
        # NOTE(review): the probe `env` built above is not returned on this
        # path; confirm whether it holds resources that need releasing.
        return SubprocVecEnv([get_env_fn(i) for i in range(args.vec_env)]), args
    return env, args


if __name__ == '__main__':
    # Script entry point: parse arguments, set up logging and seeding, build
    # the environment, then either evaluate or train with the Runner.
    args = get_args()

    torch.backends.cudnn.benchmark = True

    # exist_ok=True creates the directory only when missing, without the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(args.save_dir, exist_ok=True)

    # From here on, stdout goes through the project's Logger, which is
    # constructed with the path of train.log inside the save directory.
    logger_path = os.path.join(args.save_dir, "train.log")
    sys.stdout = Logger(logger_path)

    set_all_seeds(args.seed)

    # make_env also fills in n_agents / obs_shape / action_shape on args,
    # so the Runner must be built from the returned args.
    env, args = make_env(args)
    runner = Runner(args, env)

    # Dump the final configuration, including the fields added by make_env.
    pprint.pprint(vars(args))

    if args.evaluate:
        returns = runner.evaluate()
        print('Average returns is', returns)
    else:
        runner.run()